package com.ruoyi.framework.aop;

import cn.hutool.core.util.StrUtil;
import com.google.common.collect.ImmutableList;
import com.ruoyi.framework.web.domain.AjaxResult;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.Locale;
import java.util.Objects;


/**
 * <p>
 * Rate-limiting AOP interceptor.
 * </p>
 *
 * <p>Intercepts methods annotated with {@link LimitRequestAnnotation}, updates a Redis counter via a
 * Lua script and, once the returned count exceeds the annotation's {@code count}, writes an error
 * payload straight to the servlet response instead of invoking the target method.</p>
 *
 * @author 杨航
 * @since 2019-03-06 12:46
 */
@Slf4j
@Aspect
public class LimitRequestInterceptor {

    /** Segment placed between the annotation's prefix and the type-specific part of the key. */
    private static final String KEY_SEGMENT = "limit:";

    @Autowired
    private RedisTemplate<String, Serializable> redisTemplate;

    /**
     * Counter script, built once instead of on every request (DefaultRedisScript also caches its SHA1).
     * NOTE(review): assumes LuaScriptUtils.buildLuaScript() always returns the same script — confirm.
     */
    private final RedisScript<Number> limitScript =
            new DefaultRedisScript<>(LuaScriptUtils.buildLuaScript(), Number.class);

    /**
     * Around advice applied to every method carrying {@link LimitRequestAnnotation}.
     *
     * <p>Requests whose {@code SecurityConstants.FROM} header equals {@code SecurityConstants.FROM_IN}
     * bypass the limiter. Otherwise the Redis script is executed for the resolved key with the
     * annotation's {@code count} and {@code period}; the target method runs only while the returned
     * count is non-null and at most {@code count}. When the limit is exceeded an error payload is
     * written to the current response and {@code null} is returned.</p>
     *
     * <p>Exceptions thrown by the target method, by Redis, or while writing the rejection are
     * propagated to the caller (via {@code @SneakyThrows}) instead of being swallowed.</p>
     *
     * @param proceedingJoinPoint    the intercepted invocation
     * @param limitRequestAnnotation the annotation bound from the intercepted method
     * @return the target method's result, or {@code null} when the request was rejected
     */
    @SneakyThrows
    @Around("@annotation(limitRequestAnnotation)")
    public Object limitInterceptor(ProceedingJoinPoint proceedingJoinPoint,
            LimitRequestAnnotation limitRequestAnnotation) {
        // Current HttpServletRequest, used for the bypass header and IP-based keys.
        HttpServletRequest request = HttpUtils.getHttpServletRequest();

        // Internal requests are let through without counting.
        // NOTE(review): this header is client-controllable unless a gateway strips it; any caller
        // sending it bypasses rate limiting — confirm it is sanitised upstream.
        String header = request.getHeader(SecurityConstants.FROM);
        if (StrUtil.isNotEmpty(header) && StrUtil.equals(SecurityConstants.FROM_IN, header)) {
            return proceedingJoinPoint.proceed();
        }

        // The @annotation binding should always supply the annotation; kept as a defensive guard.
        if (Objects.isNull(limitRequestAnnotation)) {
            return proceedingJoinPoint.proceed();
        }

        int limitPeriod = limitRequestAnnotation.period();
        int limitCount = limitRequestAnnotation.count();
        String key = resolveKey(proceedingJoinPoint, limitRequestAnnotation, request);
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limitRequestAnnotation.prefix(), key));

        Number count = this.redisTemplate.execute(this.limitScript, keys, limitCount, limitPeriod);
        log.debug("Access try count is {} for name={} and key = {}", count, limitRequestAnnotation.name(), key);

        if (Objects.nonNull(count) && count.intValue() <= limitCount) {
            return proceedingJoinPoint.proceed();
        }
        rejectCurrentRequest();
        return null;
    }

    /**
     * Builds the type-specific part of the limiter key (the annotation prefix is prepended by the caller).
     *
     * @param joinPoint  the intercepted invocation, used for method-name keys
     * @param annotation the limiter configuration
     * @param request    the current request, used for IP keys
     * @return {@code "limit:"} followed by the client IP, the custom value, or the upper-cased method name
     */
    private String resolveKey(ProceedingJoinPoint joinPoint, LimitRequestAnnotation annotation,
            HttpServletRequest request) {
        switch (annotation.limitType()) {
            // Per-client-IP limiting.
            // NOTE(review): the method is not part of this key, so IP-limited endpoints sharing a
            // prefix share one counter per IP — confirm this is intended.
            case IP:
                return KEY_SEGMENT + RequestIpUtils.getIpAddr(request);
            // Caller-defined key taken from the annotation's value().
            case CUSTOMER:
                return KEY_SEGMENT + annotation.value();
            default:
                Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
                // Locale.ROOT keeps keys identical across nodes regardless of the default locale.
                return KEY_SEGMENT + StringUtils.upperCase(method.getName(), Locale.ROOT);
        }
    }

    /**
     * Writes the rejection payload to the response bound to the current thread, if any.
     *
     * @throws IOException if writing the payload fails
     */
    private void rejectCurrentRequest() throws IOException {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletResponse response = attributes == null ? null : attributes.getResponse();
        if (response == null) {
            log.warn("Rate limit exceeded but no HttpServletResponse is bound to the current thread");
            return;
        }
        this.returnData(response);
    }

    /**
     * Writes the "too fast, please retry later" error payload to {@code response}.
     *
     * @param response the response to write to
     *
     * @throws IOException if the response writer cannot be obtained or written to
     */
    public void returnData(HttpServletResponse response) throws IOException {
        response.setCharacterEncoding("UTF-8");
        response.setContentType("application/json; charset=utf-8");
        // The payload wrapper can be swapped for the project's own response class.
        AjaxResult ajaxResult = AjaxResult.error("速度太快了，请稍后重试");
        // NOTE(review): println relies on AjaxResult.toString(); unless that yields JSON, the body does
        // not match the declared content type — consider serialising with the project's JSON mapper.
        response.getWriter().println(ajaxResult);
    }
}
